week 7: multilevel models

multilevel adventures

mlm

Often, there are opportunities to cluster your observations – repeated measures, group membership, hierarchies, even different measures for the same participant. Whenever you can cluster, you should!

  • Aggregation is bad
  • Regressions within regressions (i.e., coefficients as outcomes)
  • Questions at different levels
  • Variance decomposition
  • Learning from other data through pooling/shrinkage
  • Parameters that depend on parameters
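The variance-decomposition idea can be sketched in two lines. This is a toy calculation with hypothetical numbers (not fit to any data here): the intraclass correlation (ICC) is the share of total variance attributable to differences between clusters.

```r
# toy sketch of variance decomposition in a two-level model.
# the ICC is the proportion of total variance due to between-cluster
# differences.
sigma_alpha <- 0.06  # between-person sd (hypothetical)
sigma       <- 0.05  # within-person (residual) sd (hypothetical)
icc <- sigma_alpha^2 / (sigma_alpha^2 + sigma^2)
round(icc, 2)  # 0.59 -- most of the variance here is between people
```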

multilevel people

data_path = "https://raw.githubusercontent.com/sjweston/uobayes/refs/heads/main/files/data/external_data/mlm.csv"
d <- read.csv(data_path)

rethinking::precis(d)
             mean          sd      5.5%     94.5%    histogram
id    119.1200000 50.69207461 48.000000 209.00000 ▁▂▂▅▇▇▃▅▂▂▂▁
group         NaN          NA        NA        NA             
time    1.8133333  1.75020407  0.000000   6.00000 ▇▃▁▂▁▇▁▁▁▁▁▁
wave    1.8133333  0.80222904  1.000000   3.00000       ▇▇▁▃▁▁
con     0.1910387  0.07341358  0.085244   0.30958   ▁▂▃▇▃▂▁▁▁▁
dan     0.1942627  0.06363019  0.098012   0.29588     ▁▁▅▇▇▃▁▁
d %>% count(id) %>% count(n)
  n nn
1 2 54
2 3 31
3 4  6

What if we wanted to estimate each person’s conscientiousness score? One method would be to simply average scores for each person, but we lose a lot of information that way. Another option would be to treat each person as a group and model scores as a function of group. We can do this using an unpooled model.
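For comparison, the "just average" approach is a one-liner (using the d loaded above). It throws away information about how many observations each mean is based on and how noisy each one is:

```r
# naive approach: one mean per person, ignoring sample size and noise
person_means <- d %>% 
  group_by(id) %>% 
  summarise(mean_con = mean(con), n_obs = n())
head(person_means)
```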

\[\begin{align*} \text{con}_i &\sim \text{Normal}(\mu_i,\sigma) \\ \mu_i &= \alpha_{\text{id}[i]} \\ \alpha_j &\sim \text{Normal}(0, 1.5) \text{ for }j=1,...,91 \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]

d$id.f = as.factor(d$id)

m1 <- brm(
  data=d,
  family=gaussian,
  bf(con ~ 0 + a,
     a ~ 0 + id.f,
     nl = TRUE),
  prior = c( prior(normal(0, 1.5), class=b, nlpar=a),
             prior(exponential(1), class=sigma)),
  iter=2000, warmup=1000, chains=4, cores=4, seed=9,
  file=here("files/models/m71.1")
)
m1
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: con ~ 0 + a 
         a ~ 0 + id.f
   Data: d (Number of observations: 225) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
a_id.f6       0.19      0.03     0.14     0.24 1.00     5501     3219
a_id.f29      0.12      0.03     0.06     0.19 1.00     4997     3055
a_id.f34      0.11      0.03     0.04     0.17 1.00     6650     2832
a_id.f36      0.16      0.03     0.09     0.22 1.00     5421     2526
a_id.f37      0.19      0.03     0.13     0.24 1.00     6492     3113
a_id.f48      0.22      0.03     0.16     0.27 1.00     7123     3060
a_id.f53      0.19      0.03     0.12     0.26 1.00     5655     2762
a_id.f54      0.24      0.03     0.19     0.30 1.00     5322     2211
a_id.f58      0.23      0.03     0.18     0.29 1.00     5513     2776
a_id.f61      0.09      0.03     0.02     0.15 1.00     6327     3071
a_id.f66      0.24      0.03     0.19     0.29 1.00     6463     3063
a_id.f67      0.20      0.02     0.16     0.25 1.00     6426     3093
a_id.f69      0.21      0.03     0.14     0.28 1.00     4988     2809
a_id.f71      0.06      0.03     0.00     0.11 1.00     6187     2985
a_id.f74      0.20      0.03     0.15     0.26 1.00     5603     2743
a_id.f75      0.27      0.03     0.20     0.34 1.00     5780     3019
a_id.f76      0.10      0.03     0.04     0.16 1.00     5413     2869
a_id.f78      0.24      0.03     0.19     0.30 1.00     5511     2737
a_id.f79      0.13      0.03     0.06     0.19 1.00     5286     3061
a_id.f80      0.14      0.03     0.07     0.21 1.00     6782     2917
a_id.f81      0.24      0.03     0.19     0.29 1.00     5755     3230
a_id.f82      0.33      0.02     0.28     0.37 1.00     5399     2890
a_id.f85      0.16      0.03     0.09     0.22 1.00     5082     2776
a_id.f86      0.12      0.03     0.05     0.18 1.00     6152     2838
a_id.f87      0.11      0.03     0.05     0.18 1.00     6708     3065
a_id.f89      0.13      0.03     0.08     0.19 1.00     6581     2657
a_id.f91      0.17      0.03     0.12     0.22 1.00     5214     3005
a_id.f92      0.20      0.03     0.15     0.25 1.00     6350     3084
a_id.f93      0.25      0.03     0.20     0.31 1.00     6158     2695
a_id.f94      0.12      0.03     0.06     0.17 1.00     6642     3086
a_id.f96      0.23      0.03     0.17     0.30 1.00     4951     2897
a_id.f97      0.25      0.02     0.20     0.29 1.00     6265     3101
a_id.f98      0.14      0.02     0.09     0.19 1.00     5142     2795
a_id.f99      0.18      0.02     0.14     0.23 1.00     5826     2850
a_id.f101     0.25      0.03     0.20     0.30 1.00     6668     2935
a_id.f102     0.19      0.02     0.14     0.23 1.00     5504     2855
a_id.f103     0.15      0.03     0.09     0.22 1.00     5462     2604
a_id.f104     0.17      0.03     0.10     0.24 1.00     6801     2667
a_id.f105     0.16      0.03     0.11     0.21 1.00     5765     3038
a_id.f106     0.20      0.03     0.14     0.25 1.00     6377     3039
a_id.f110     0.18      0.03     0.13     0.24 1.00     5455     3009
a_id.f112     0.09      0.03     0.04     0.14 1.00     5278     2985
a_id.f114     0.13      0.03     0.08     0.18 1.00     5166     2492
a_id.f115     0.14      0.03     0.08     0.19 1.00     6953     2513
a_id.f116     0.14      0.03     0.09     0.20 1.00     5657     3057
a_id.f120     0.35      0.03     0.29     0.42 1.00     5768     3102
a_id.f122     0.17      0.03     0.12     0.22 1.00     5392     2892
a_id.f125     0.20      0.03     0.14     0.25 1.00     6633     2693
a_id.f127     0.23      0.03     0.17     0.28 1.00     5495     3188
a_id.f129     0.13      0.03     0.06     0.19 1.00     5429     2923
a_id.f135     0.30      0.03     0.24     0.35 1.00     5426     2677
a_id.f136     0.29      0.03     0.24     0.34 1.00     5643     3113
a_id.f137     0.29      0.03     0.23     0.36 1.00     5445     2817
a_id.f140     0.13      0.03     0.06     0.19 1.00     6236     3438
a_id.f141     0.19      0.03     0.12     0.26 1.00     5091     3025
a_id.f142     0.18      0.03     0.11     0.24 1.00     5232     2652
a_id.f143     0.31      0.03     0.25     0.38 1.00     5705     2873
a_id.f144     0.17      0.03     0.11     0.24 1.00     5553     2678
a_id.f146     0.21      0.03     0.14     0.28 1.00     5345     2768
a_id.f149     0.23      0.03     0.16     0.29 1.00     5531     2612
a_id.f150     0.20      0.03     0.13     0.27 1.00     6163     2830
a_id.f152     0.22      0.03     0.16     0.29 1.00     5220     2066
a_id.f153     0.15      0.03     0.08     0.21 1.00     5373     2711
a_id.f155     0.14      0.03     0.08     0.21 1.00     6550     2174
a_id.f156     0.14      0.03     0.08     0.19 1.00     6065     2818
a_id.f159     0.12      0.03     0.06     0.19 1.00     4996     3074
a_id.f160     0.18      0.03     0.11     0.24 1.00     5892     3125
a_id.f162     0.39      0.03     0.34     0.45 1.00     7084     2544
a_id.f163     0.13      0.03     0.07     0.20 1.00     5687     2752
a_id.f165     0.20      0.03     0.13     0.26 1.00     6052     2742
a_id.f167     0.13      0.03     0.07     0.20 1.00     5675     2796
a_id.f169     0.24      0.03     0.18     0.31 1.00     5811     2772
a_id.f171     0.23      0.03     0.16     0.29 1.00     7163     3017
a_id.f174     0.26      0.03     0.19     0.33 1.00     6936     3011
a_id.f182     0.19      0.03     0.12     0.26 1.00     5793     3007
a_id.f187     0.06      0.03    -0.00     0.13 1.00     5937     2874
a_id.f189     0.22      0.03     0.15     0.29 1.00     6351     2464
a_id.f190     0.13      0.03     0.06     0.19 1.00     6706     2867
a_id.f193     0.20      0.03     0.14     0.27 1.00     5815     2656
a_id.f194     0.12      0.03     0.05     0.19 1.00     6378     2715
a_id.f201     0.19      0.03     0.12     0.25 1.00     5999     3272
a_id.f204     0.29      0.03     0.22     0.35 1.00     5491     2828
a_id.f205     0.19      0.03     0.13     0.26 1.00     5635     2555
a_id.f208     0.20      0.03     0.14     0.27 1.00     5887     3309
a_id.f209     0.21      0.03     0.15     0.28 1.00     6057     3225
a_id.f211     0.14      0.03     0.07     0.20 1.00     5253     2976
a_id.f214     0.28      0.03     0.21     0.34 1.00     5544     2671
a_id.f219     0.23      0.03     0.16     0.30 1.00     5851     2773
a_id.f222     0.14      0.03     0.07     0.20 1.00     5835     2357
a_id.f223     0.17      0.03     0.10     0.23 1.00     5549     3003
a_id.f229     0.14      0.03     0.07     0.21 1.00     4735     2709

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.05      0.00     0.04     0.05 1.00     2112     2764

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

This is inefficient: the model treats each person as entirely separate, so nothing learned about one person improves the estimate for another. Let’s try a partial pooling model.

\[\begin{align*} \text{con}_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{\text{id}[i]} \\ \alpha_j &\sim \text{Normal}(\bar{\alpha}, \sigma_{\alpha}) \text{ for }j=1,...,91\\ \bar{\alpha} &\sim \text{Normal}(0, 1.5)\\ \sigma_{\alpha} &\sim \text{Exponential}(1) \\ \sigma &\sim \text{Exponential}(1) \\ \end{align*}\]
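Before fitting, here is a back-of-the-envelope sketch of what partial pooling does in the normal-normal case, with hypothetical values: a person's estimate is approximately a precision-weighted average of their own raw mean and the grand mean, with more shrinkage when they contribute fewer observations.

```r
# shrinkage sketch for one person (all values hypothetical)
n_j         <- 2     # observations for this person
ybar_j      <- 0.30  # this person's raw mean
alpha_bar   <- 0.19  # grand mean
sigma       <- 0.05  # residual sd
sigma_alpha <- 0.06  # between-person sd

# precision weight on the person's own data
w <- (n_j / sigma^2) / (n_j / sigma^2 + 1 / sigma_alpha^2)
alpha_j <- w * ybar_j + (1 - w) * alpha_bar
round(alpha_j, 2)  # 0.27 -- pulled from 0.30 toward 0.19
```

With more observations per person, `w` approaches 1 and the estimate approaches the raw mean; with fewer, it is pulled harder toward the grand mean.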

m2 <- brm(
  data=d,
  family=gaussian,
  con ~ 1 + (1 | id), 
  prior = c( prior(normal(0, 1.5), class=Intercept),
             prior(exponential(1), class=sd),
             prior(exponential(1), class=sigma)),
  iter=2000, warmup=1000, chains=4, cores=4, seed=9,
  file=here("files/models/m71.2")
)
m2
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: con ~ 1 + (1 | id) 
   Data: d (Number of observations: 225) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~id (Number of levels: 91) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.06      0.01     0.05     0.07 1.00     1484     2271

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     0.19      0.01     0.18     0.20 1.00     2156     2901

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.05      0.00     0.04     0.05 1.00     2526     2383

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

how many parameters does each model have?

get_variables(m1) 
  [1] "b_a_id.f6"     "b_a_id.f29"    "b_a_id.f34"    "b_a_id.f36"   
  [5] "b_a_id.f37"    "b_a_id.f48"    "b_a_id.f53"    "b_a_id.f54"   
  [9] "b_a_id.f58"    "b_a_id.f61"    "b_a_id.f66"    "b_a_id.f67"   
 [13] "b_a_id.f69"    "b_a_id.f71"    "b_a_id.f74"    "b_a_id.f75"   
 [17] "b_a_id.f76"    "b_a_id.f78"    "b_a_id.f79"    "b_a_id.f80"   
 [21] "b_a_id.f81"    "b_a_id.f82"    "b_a_id.f85"    "b_a_id.f86"   
 [25] "b_a_id.f87"    "b_a_id.f89"    "b_a_id.f91"    "b_a_id.f92"   
 [29] "b_a_id.f93"    "b_a_id.f94"    "b_a_id.f96"    "b_a_id.f97"   
 [33] "b_a_id.f98"    "b_a_id.f99"    "b_a_id.f101"   "b_a_id.f102"  
 [37] "b_a_id.f103"   "b_a_id.f104"   "b_a_id.f105"   "b_a_id.f106"  
 [41] "b_a_id.f110"   "b_a_id.f112"   "b_a_id.f114"   "b_a_id.f115"  
 [45] "b_a_id.f116"   "b_a_id.f120"   "b_a_id.f122"   "b_a_id.f125"  
 [49] "b_a_id.f127"   "b_a_id.f129"   "b_a_id.f135"   "b_a_id.f136"  
 [53] "b_a_id.f137"   "b_a_id.f140"   "b_a_id.f141"   "b_a_id.f142"  
 [57] "b_a_id.f143"   "b_a_id.f144"   "b_a_id.f146"   "b_a_id.f149"  
 [61] "b_a_id.f150"   "b_a_id.f152"   "b_a_id.f153"   "b_a_id.f155"  
 [65] "b_a_id.f156"   "b_a_id.f159"   "b_a_id.f160"   "b_a_id.f162"  
 [69] "b_a_id.f163"   "b_a_id.f165"   "b_a_id.f167"   "b_a_id.f169"  
 [73] "b_a_id.f171"   "b_a_id.f174"   "b_a_id.f182"   "b_a_id.f187"  
 [77] "b_a_id.f189"   "b_a_id.f190"   "b_a_id.f193"   "b_a_id.f194"  
 [81] "b_a_id.f201"   "b_a_id.f204"   "b_a_id.f205"   "b_a_id.f208"  
 [85] "b_a_id.f209"   "b_a_id.f211"   "b_a_id.f214"   "b_a_id.f219"  
 [89] "b_a_id.f222"   "b_a_id.f223"   "b_a_id.f229"   "sigma"        
 [93] "lprior"        "lp__"          "accept_stat__" "stepsize__"   
 [97] "treedepth__"   "n_leapfrog__"  "divergent__"   "energy__"     

how many parameters does each model have?

get_variables(m1) %>% length()
[1] 100
get_variables(m2) %>% length()
[1] 103

What additional parameters?

m1 has a unique intercept for each participant and a standard deviation of scores (1 \(\sigma\)).

m2 is estimating all of that plus a grand mean intercept and the variability of means (\(\sigma_M\)).

(what’s the extra one? brms lists the intercept twice. *shrug emoji*)

And yet!

m1 <- add_criterion(m1, criterion = "loo")
m2 <- add_criterion(m2, criterion = "loo")

loo_compare(m1, m2) %>% print(simplify=F)
   elpd_diff se_diff elpd_loo se_elpd_loo p_loo  se_p_loo looic  se_looic
m2    0.0       0.0   325.1     11.3        65.7    5.2   -650.2   22.6  
m1  -14.5       5.8   310.6      9.4        82.9    5.3   -621.1   18.9  

Let’s visualize the differences in these.

nd1 = distinct(d, id.f)
post1 = epred_draws(m1, nd1)
nd2 = distinct(d, id)
post2 = epred_draws(m2, nd2)
p1 = post1 %>% 
  ggplot( aes(y=.epred, x=id.f) ) +
  stat_gradientinterval() +
  scale_x_discrete(labels=NULL, breaks=NULL) +
  labs(x="id", y="con", title = "no pooling")

p2 = post2 %>% 
  mutate(id=as.factor(id)) %>% 
  ggplot( aes(y=.epred, x=id) ) +
  stat_gradientinterval() +
  scale_x_discrete(labels=NULL, breaks=NULL) +
  labs(x="id", y="con", title = "partial pooling")

p1 / p2
means1 = post1 %>% 
  mean_qi(.epred)
means2 = post2 %>% 
  mean_qi(.epred) %>% 
  mutate(id=as.factor(id))

means1 %>% 
  ggplot( aes(x=id.f, y=.epred)) +
  geom_hline( aes(yintercept=mean(.epred)),
              linetype="dashed") +
  geom_point( aes(color="no pooling") ) +
  geom_point( aes(x=id, color="partial pooling"),
              data=means2,
              size=2) +
  scale_color_manual( values=c("#e07a5f", "#1c5253") ) +
  scale_x_discrete(breaks=NULL) +
  labs(x="id", y="con")+
  theme(legend.position = "top")

extracting estimates

Yikes!

as_draws_df(m2) %>% round(2)
# A draws_df: 1000 iterations, 4 chains, and 97 variables
   b_Intercept sd_id__Intercept sigma Intercept r_id[6,Intercept]
1         0.19             0.05  0.05      0.19              0.00
2         0.19             0.06  0.04      0.19             -0.01
3         0.18             0.05  0.05      0.18              0.00
4         0.18             0.05  0.05      0.18              0.01
5         0.19             0.07  0.05      0.19              0.00
6         0.19             0.06  0.05      0.19             -0.02
7         0.19             0.05  0.05      0.19             -0.02
8         0.19             0.05  0.05      0.19             -0.03
9         0.20             0.05  0.05      0.20              0.02
10        0.19             0.05  0.05      0.19             -0.05
   r_id[29,Intercept] r_id[34,Intercept] r_id[36,Intercept]
1               -0.07              -0.05               0.03
2               -0.03              -0.07              -0.03
3               -0.04              -0.07               0.03
4               -0.06              -0.06               0.01
5               -0.05              -0.12              -0.05
6               -0.06              -0.06               0.01
7               -0.03              -0.05              -0.03
8               -0.02              -0.06              -0.04
9               -0.09              -0.05               0.01
10               0.01              -0.07              -0.05
# ... with 3990 more draws, and 89 more variables
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}

extracting estimates: a better way

These estimates are deviations from the weighted average. Also note the grouped structure of this object.

m2 %>% spread_draws(r_id[id, term])
Note: id and term become the names of the columns in the resulting object. You can use whatever strings you want here.
# A tibble: 364,000 × 6
# Groups:   id, term [91]
      id term           r_id .chain .iteration .draw
   <int> <chr>         <dbl>  <int>      <int> <int>
 1     6 Intercept -0.000943      1          1     1
 2     6 Intercept -0.0100        1          2     2
 3     6 Intercept  0.00390       1          3     3
 4     6 Intercept  0.00956       1          4     4
 5     6 Intercept -0.00155       1          5     5
 6     6 Intercept -0.0183        1          6     6
 7     6 Intercept -0.0190        1          7     7
 8     6 Intercept -0.0314        1          8     8
 9     6 Intercept  0.0203        1          9     9
10     6 Intercept -0.0505        1         10    10
# ℹ 363,990 more rows

extracting estimates: a better way

These estimates are deviations from the weighted average. Also note the grouped structure of this object.

m2 %>% spread_draws(r_id[id, term]) %>% 
  mean_qi()
# A tibble: 91 × 8
      id term           r_id   .lower   .upper .width .point .interval
   <int> <chr>         <dbl>    <dbl>    <dbl>  <dbl> <chr>  <chr>    
 1     6 Intercept  0.000264 -0.0488   0.0493    0.95 mean   qi       
 2    29 Intercept -0.0481   -0.107    0.00842   0.95 mean   qi       
 3    34 Intercept -0.0618   -0.119   -0.00632   0.95 mean   qi       
 4    36 Intercept -0.0219   -0.0796   0.0352    0.95 mean   qi       
 5    37 Intercept -0.00301  -0.0530   0.0474    0.95 mean   qi       
 6    48 Intercept  0.0235   -0.0242   0.0705    0.95 mean   qi       
 7    53 Intercept  0.000586 -0.0536   0.0556    0.95 mean   qi       
 8    54 Intercept  0.0451   -0.00463  0.0930    0.95 mean   qi       
 9    58 Intercept  0.0368   -0.0139   0.0848    0.95 mean   qi       
10    61 Intercept -0.0742   -0.134   -0.0139    0.95 mean   qi       
# ℹ 81 more rows

If you want the actual person-level means (not deviations from the grand mean), just add in the estimate of the grand mean.

m2 %>% spread_draws(b_Intercept, r_id[id, term]) %>% 
  mutate(r_id = b_Intercept + r_id) %>% 
  mean_qi(r_id)
# A tibble: 91 × 8
      id term       r_id .lower .upper .width .point .interval
   <int> <chr>     <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
 1     6 Intercept 0.190 0.142   0.237   0.95 mean   qi       
 2    29 Intercept 0.141 0.0834  0.197   0.95 mean   qi       
 3    34 Intercept 0.127 0.0713  0.183   0.95 mean   qi       
 4    36 Intercept 0.167 0.110   0.224   0.95 mean   qi       
 5    37 Intercept 0.186 0.138   0.235   0.95 mean   qi       
 6    48 Intercept 0.213 0.166   0.259   0.95 mean   qi       
 7    53 Intercept 0.190 0.137   0.244   0.95 mean   qi       
 8    54 Intercept 0.234 0.186   0.282   0.95 mean   qi       
 9    58 Intercept 0.226 0.176   0.275   0.95 mean   qi       
10    61 Intercept 0.115 0.0577  0.175   0.95 mean   qi       
# ℹ 81 more rows

more than one type of cluster

McElreath doesn’t cover this in his video lecture, but this is from the textbook and worth discussing.

data(chimpanzees, package="rethinking")
d <- chimpanzees
rethinking::precis(d)
                   mean         sd  5.5%  94.5%    histogram
actor         4.0000000  2.0019871 1.000  7.000 ▇▇▁▇▁▇▁▇▁▇▁▇
recipient     5.0000000  2.0039801 2.000  8.000 ▇▇▁▇▁▇▁▇▁▇▁▇
condition     0.5000000  0.5004968 0.000  1.000   ▇▁▁▁▁▁▁▁▁▇
block         3.5000000  1.7095219 1.000  6.000   ▇▇▁▇▁▇▁▇▁▇
trial        36.5000000 20.8032533 4.665 68.335     ▇▇▇▇▇▇▇▁
prosoc_left   0.5000000  0.5004968 0.000  1.000   ▇▁▁▁▁▁▁▁▁▇
chose_prosoc  0.5674603  0.4959204 0.000  1.000   ▅▁▁▁▁▁▁▁▁▇
pulled_left   0.5793651  0.4941515 0.000  1.000   ▅▁▁▁▁▁▁▁▁▇
unique(d$actor)
[1] 1 2 3 4 5 6 7
unique(d$block)
[1] 1 2 3 4 5 6
unique(d$prosoc_left)
[1] 0 1
unique(d$condition)
[1] 0 1

We could model the interaction between condition (presence/absence of another animal) and option (which side is prosocial), but it is more difficult to assign sensible priors to interaction effects. Another option, because we’re working with categorical variables, is to turn our 2x2 into one variable with 4 levels.

d$treatment <- factor(1 + d$prosoc_left + 2*d$condition)
d %>% count(treatment, prosoc_left, condition)
  treatment prosoc_left condition   n
1         1           0         0 126
2         2           1         0 126
3         3           0         1 126
4         4           1         1 126

In this experiment, each pull is within a cluster of pulls belonging to an individual chimpanzee. But each pull is also within an experimental block, which represents a collection of observations that happened on the same day. So each observed pull belongs to both an actor (1 to 7) and a block (1 to 6). There may be unique intercepts for each actor as well as for each block.

Mathematical model:

\[\begin{align*} L_i &\sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) &= \bar{\alpha} + \alpha_{\text{ACTOR}[i]} + \gamma_{\text{BLOCK}[i]} + \beta_{\text{TREATMENT}[i]} \\ \beta_j &\sim \text{Normal}(0, 0.5) \text{ , for }j=1,...,4\\ \alpha_j &\sim \text{Normal}(0, \sigma_{\alpha}) \text{ , for }j=1,...,7\\ \gamma_j &\sim \text{Normal}(0, \sigma_{\gamma}) \text{ , for }j=1,...,6\\ \bar{\alpha} &\sim \text{Normal}(0, 1.5) \\ \sigma_{\alpha} &\sim \text{Exponential}(1) \\ \sigma_{\gamma} &\sim \text{Exponential}(1) \\ \end{align*}\]

m3 <- 
  brm(
    family = bernoulli,
    data = d, 
    bf(
      pulled_left ~ a + b, 
      a ~ 1 + (1 | actor) + (1 | block), 
      b ~ 0 + treatment, 
      nl = TRUE),
    prior = c(prior(normal(0, 0.5), nlpar = b),
              prior(normal(0, 1.5), class = b, coef = Intercept, nlpar = a),
              prior(exponential(1), class = sd, group = actor, nlpar = a),
              prior(exponential(1), class = sd, group = block, nlpar = a)),
  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/models/71.3")
  )
m3
 Family: bernoulli 
  Links: mu = logit 
Formula: pulled_left ~ a + b 
         a ~ 1 + (1 | actor) + (1 | block)
         b ~ 0 + treatment
   Data: d (Number of observations: 504) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~actor (Number of levels: 7) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     2.04      0.66     1.12     3.63 1.01     1424     2076

~block (Number of levels: 6) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     0.21      0.17     0.01     0.63 1.00     1587     1660

Regression Coefficients:
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
a_Intercept      0.58      0.71    -0.76     2.01 1.00     1136     1789
b_treatment1    -0.13      0.30    -0.71     0.46 1.00     2102     2948
b_treatment2     0.40      0.30    -0.20     0.99 1.00     1820     2576
b_treatment3    -0.48      0.30    -1.06     0.11 1.00     1937     2509
b_treatment4     0.28      0.30    -0.29     0.89 1.00     1910     2502

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
posterior_summary(m3)
                             Estimate Est.Error          Q2.5        Q97.5
b_a_Intercept            5.835058e-01 0.7051870 -7.574382e-01    2.0133561
b_b_treatment1          -1.285035e-01 0.3000948 -7.111530e-01    0.4608178
b_b_treatment2           3.968713e-01 0.2977210 -2.006941e-01    0.9864687
b_b_treatment3          -4.770727e-01 0.3000809 -1.057934e+00    0.1067693
b_b_treatment4           2.815642e-01 0.2984428 -2.944552e-01    0.8881553
sd_actor__a_Intercept    2.036780e+00 0.6561266  1.116480e+00    3.6306757
sd_block__a_Intercept    2.088196e-01 0.1743707  8.311932e-03    0.6295848
r_actor__a[1,Intercept] -9.391636e-01 0.7061219 -2.360315e+00    0.3928665
r_actor__a[2,Intercept]  4.108310e+00 1.3648897  2.040292e+00    7.2360335
r_actor__a[3,Intercept] -1.245105e+00 0.7113142 -2.669128e+00    0.1290459
r_actor__a[4,Intercept] -1.245720e+00 0.7152689 -2.676802e+00    0.1138210
r_actor__a[5,Intercept] -9.385328e-01 0.7167749 -2.405589e+00    0.4292875
r_actor__a[6,Intercept]  1.203465e-03 0.7156267 -1.441203e+00    1.3720687
r_actor__a[7,Intercept]  1.526423e+00 0.7600257  4.334721e-02    3.0016306
r_block__a[1,Intercept] -1.664840e-01 0.2169600 -7.073766e-01    0.1194874
r_block__a[2,Intercept]  3.467096e-02 0.1698465 -2.979211e-01    0.4263189
r_block__a[3,Intercept]  4.830452e-02 0.1783436 -2.839705e-01    0.4701920
r_block__a[4,Intercept]  1.138782e-02 0.1732665 -3.519582e-01    0.3951538
r_block__a[5,Intercept] -2.934102e-02 0.1689017 -4.049284e-01    0.3082666
r_block__a[6,Intercept]  1.078886e-01 0.1934904 -1.977606e-01    0.5785745
lprior                  -6.336549e+00 1.2258524 -9.231991e+00   -4.5017996
lp__                    -2.866412e+02 3.7297778 -2.947833e+02 -280.1872116
m3 %>% 
  mcmc_plot(variable = c("^r_", "^b_", "^sd_"), regex = T) +
  theme(axis.text.y = element_text(hjust = 0))

Zooming in on just the actor and block effects. (Remember, these are differences from the weighted average.)

m3 %>% 
  mcmc_plot(variable = c("^r_"), regex = T) +
  theme(axis.text.y = element_text(hjust = 0))
as_draws_df(m3) %>% 
  select(starts_with("sd")) %>% 
  pivot_longer(everything()) %>% 
  ggplot(aes(x = value, fill = name)) +
  geom_density(linewidth = 0, alpha = 3/4, adjust = 2/3, show.legend = F) +
  annotate(geom = "text", x = 0.67, y = 2, label = "block", color = "#5e8485") +
  annotate(geom = "text", x = 2.725, y = 0.5, label = "actor", color = "#0f393a") +
  scale_fill_manual(values = c("#0f393a", "#5e8485")) +
  scale_y_continuous(NULL, breaks = NULL) +
  ggtitle(expression(sigma["group"])) +
  coord_cartesian(xlim = c(0, 4))

exercise

Return to the data(Trolley) from an earlier lecture. Define and fit a varying intercepts model for these data, with responses clustered within participants. Include action, intention, and contact. Compare the varying-intercepts model and the model that ignores individuals using some method of cross-validation.

solution

data(Trolley, package="rethinking")

# fit model without varying intercepts
m_simple <- brm(
  data = Trolley,
  family = cumulative, 
  response ~ 1 + action + intention + contact, 
  prior = c( prior(normal(0, 1.5), class = Intercept) ),
  iter=2000, warmup=1000, cores=4, chains=4,
  file=here("files/data/generated_data/m71.e1")
)

# fit model with varying intercepts
m_varying <- brm(
  data = Trolley,
  family = cumulative, 
  response ~ 1 + action + intention + contact + (1|id), 
  prior = c( prior(normal(0, 1.5), class = Intercept),
             prior(normal(0, 0.5), class = b),
             prior(exponential(1), class = sd)),
  iter=2000, warmup=1000, cores=4, chains=4,
  file=here("files/data/generated_data/m71.e2")
)

solution

# compare models using PSIS-LOO cross-validation
m_simple  <- add_criterion(m_simple , "loo")
m_varying <- add_criterion(m_varying, "loo")

loo_compare(m_simple, m_varying, criterion = "loo") %>% 
  print(simplify=F)
          elpd_diff se_diff  elpd_loo se_elpd_loo p_loo se_p_loo   looic se_looic
m_varying       0.0     0.0  -15669.2        88.7 354.2      4.6 31338.4    177.5
m_simple    -2876.0    86.2  -18545.1        38.1   9.2      0.0 37090.3     76.2
pp_check(m_simple, ndraws = 5, type="hist") +
  ggtitle("Simple Model")
pp_check(m_varying, ndraws = 5, type="hist") +
  ggtitle("Varying Intercepts Model")

predictions

Posterior predictions in multilevel models are a bit more complicated than in single-level models, because a question arises: do we want predictions for the same clusters or predictions for new clusters?

In other words, do you want to know more about the chimps you collected data on, or new chimps? Let’s talk about both.

predictions for chimps in our sample

Recall that the function fitted() gives predictions. The argument re_formula = NULL specifies that we want to include the group-level estimates in our predictions.

labels <- c("R/N", "L/N", "R/P", "L/P")

nd <- distinct(d, treatment, actor) %>% 
  mutate(block=1)

f <- fitted(m3,newdata = nd, re_formula = NULL) %>% 
  data.frame() %>% 
  bind_cols(nd) %>% 
  mutate(treatment = factor(treatment, labels = labels))
f %>% 
  ggplot( aes(x=treatment, y=Estimate, group=1) ) +
  geom_ribbon(aes( ymin=Q2.5, ymax=Q97.5 ), 
              fill = "#0f393a",
              alpha=.3) +
  geom_line(color="#0f393a") +
  scale_y_continuous(limits=c(0,1)) +
  facet_wrap(~actor)
# observed p
obs = d %>% 
  filter(block==1) %>% 
  group_by(actor, treatment) %>% 
  summarise(p = mean(pulled_left), .groups = "drop") %>% 
  mutate(treatment = factor(treatment, labels = labels))


f %>% 
  ggplot( aes(x=treatment, y=Estimate, group=1) ) +
  geom_ribbon(aes( ymin=Q2.5, ymax=Q97.5 ), 
              fill = "#0f393a",
              alpha=.3) +
  geom_point( aes(y=p), 
              data=obs, 
              shape=1) +
  geom_line(color="#0f393a") +
  facet_wrap(~actor)

We can add in the observed probabilities.

predictions for new chimps

Even here, we have some choice. Let’s start by predicting scores for the average chimp. We can use the same code as before, but set re_formula to NA.

labels <- c("R/N", "L/N", "R/P", "L/P")

nd <- distinct(d, treatment) %>% 
  mutate(block=1)

f_avg <- fitted(m3,newdata = nd, re_formula = NA) %>% 
  data.frame() %>% 
  bind_cols(nd) %>% 
  mutate(treatment = factor(treatment, labels = labels))

We’ll add the average chimp to the plot.

f %>% 
  ggplot( aes(x=treatment, y=Estimate, group=1) ) +
  geom_ribbon(aes( ymin=Q2.5, ymax=Q97.5 ), 
              fill = "#0f393a",
              alpha=.3) +
  geom_line(color="#0f393a") +
  geom_ribbon(aes( ymin=Q2.5, ymax=Q97.5 ), 
              data=f_avg,
              fill = "#e07a5f",
              alpha=.3) +
  geom_line(color="#e07a5f", data=f_avg) +
  scale_y_continuous(limits=c(0,1)) +
  facet_wrap(~actor)

But the average chimp is only one possible chimp we could encounter. Let’s simulate 100 possible chimps.

post = as_draws_df(m3)

post %>% 
  slice_sample(n=100) %>% 
  # simulate chimps
  mutate(a_sim = rnorm(n(), mean = b_a_Intercept, sd = sd_actor__a_Intercept)) %>% 
  pivot_longer(b_b_treatment1:b_b_treatment4) %>% 
  mutate(fitted = inv_logit_scaled(a_sim + value)) %>% 
  mutate(treatment = factor(str_remove(name, "b_b_treatment"),
                            labels = labels)) %>%
  ggplot(aes(x = treatment, y = fitted, group = .draw)) +
  geom_line(alpha = 1/2, color = "#e07a5f") +
  coord_cartesian(ylim = 0:1) 

exercise

Returning to the Trolley data and the varying intercept model, get predictions for…

  1. a subset of 3 participants in the dataset.
  2. the average participant.
  3. 2 new participants.

Hint: don’t forget that the model uses a link function. You may need to play with arguments or fiddle around with the outputs of your functions.
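On the link-function hint: for a cumulative(logit) model, predictions on the linear scale are logits. Base R's plogis() (or brms's inv_logit_scaled()) converts a logit back to a probability, while fitted(..., scale = "response") does this for you. A quick illustration with a hypothetical value:

```r
# the inverse-logit transformation: 1 / (1 + exp(-eta))
eta <- 0.5          # a hypothetical value on the logit scale
p   <- plogis(eta)  # base R's inverse logit
round(p, 3)  # 0.622
```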

solution

3 participants

part3 = sample( unique(Trolley$id) , size=3, replace=F )
nd <- distinct(Trolley, action, intention, contact, id) %>% 
  filter(id %in% part3)

f <- fitted(m_varying, newdata = nd, scale = "response") %>% 
  data.frame() %>% 
  bind_cols(nd) 

f %>% 
  pivot_longer(-c(action:id),
               names_sep = "\\.{3}",
               names_to = c("stat", "response")) %>% 
  pivot_wider(names_from = stat, values_from = value) %>% 
  mutate(response = str_sub(response, 1, 1)) %>% 
  ggplot(aes(x=response, y=Estimate.P.Y, fill=as.factor(intention))) +
  geom_bar(stat="identity", position="dodge") +
  labs(y="p") +
  facet_grid(action+contact~id) +
  theme(legend.position = "bottom")

solution

The average participant

nd <- distinct(Trolley, action, intention, contact) 

f <- fitted(m_varying, newdata = nd, scale = "response", 
            re_formula = NA) %>% 
  data.frame() %>% 
  bind_cols(nd) 

f %>% 
  pivot_longer(-c(action:contact),
               names_sep = "\\.{3}",
               names_to = c("stat", "response")) %>% 
  pivot_wider(names_from = stat, values_from = value) %>% 
  mutate(response = str_sub(response, 1, 1)) %>% 
  ggplot(aes(x=response, y=Estimate.P.Y, fill=as.factor(intention))) +
  geom_bar(stat="identity", position="dodge") +
  labs(y="p") +
  facet_grid(action~contact) +
  theme(legend.position = "bottom")

Two new participants

# create data for 2 new participants
nd <- distinct(Trolley, action, intention, contact) %>%
  slice(rep(1:n(), times = 2)) %>%
  mutate(id = rep(c("New1", "New2"), each = n()/2))

# get predictions including random effects
f <- fitted(m_varying, newdata = nd, 
            scale = "response", allow_new_levels=T) %>% 
  data.frame() %>% 
  bind_cols(nd) 

# plot
f %>% 
  pivot_longer(-c(action:id),
               names_sep = "\\.{3}",
               names_to = c("stat", "response")) %>% 
  pivot_wider(names_from = stat, values_from = value) %>% 
  mutate(response = str_sub(response, 1, 1)) %>% 
  ggplot(aes(x=response, y=Estimate.P.Y, fill=as.factor(intention))) +
  geom_bar(stat="identity", position="dodge") +
  labs(y="p", fill="intention") +
  facet_grid(action+contact~id) +
  theme(legend.position = "bottom")